package com.cyy.model

import java.util.ArrayList
import java.sql.ResultSet
import java.sql.DriverManager
import java.io.FileOutputStream
import com.jfinal.kit.Kv
import com.jfinal.template.Engine
import com.jfinal.template.source.ClassPathSourceFactory
import java.util.LinkedHashMap
import org.apache.log4j.Logger;
import java.io.File
import java.sql.Connection
import java.sql.DatabaseMetaData

/**
 * Metadata describing one database table.
 *
 * @property dbType database type key, e.g. "mysql"
 * @property dbName database name
 * @property tableName table name
 * @property primaryKey comma-separated primary-key column names; empty when the table has none
 * @property uniqueKeys unique-key column names
 * @property indexKeys indexed column names
 * @property remark table comment
 * @property columns metadata of the table's columns
 */
data class TableMeta(
    var dbType: String,
    var dbName: String,
    var tableName: String,
    var primaryKey: String,
    var uniqueKeys: List<String> = ArrayList(),
    var indexKeys: List<String> = ArrayList(),
    var remark: String,
    var columns: List<ColumnMeta> = ArrayList()
)

/**
 * Metadata describing one table column. Every property has a default, so an
 * instance can be created empty and filled in afterwards.
 *
 * @property tableName owning table name
 * @property columnName column name
 * @property columnType database type name, e.g. "varchar"
 * @property columnSize column size rendered as text
 * @property dataType type with size/precision, e.g. "varchar(255)"
 * @property defaultValue column default value
 * @property remark column comment
 * @property nullable nullability flag text; defaults to "YES"
 * @property primaryKey whether the column is part of the primary key
 */
data class ColumnMeta(
    var tableName: String = "",
    var columnName: String = "",
    var columnType: String = "",
    var columnSize: String = "",
    var dataType: String = "",
    var defaultValue: String = "",
    var remark: String = "",
    var nullable: String = "YES",
    var primaryKey: Boolean = false
)

/**
 * Reads table and column metadata from a JDBC database and renders it to HTML
 * using JFinal Enjoy templates.
 *
 * <font color="red">Note: to skip certain tables, override or modify [isSkip].</font>
 * @author Admin
 */
open class DbBuilder
/**
 * @param url JDBC connection URL. The database type and driver class are derived from its prefix;
 *            for MySQL, `useInformationSchema=true` is appended when missing, otherwise table and
 *            column remarks cannot be read.
 * @param user database user
 * @param password database password
 */
(url: String, user: String, password: String) {

    private val logger = Logger.getLogger(javaClass)
    // Database type key ("mysql", "oracle", ...); stays "" for unrecognized URL prefixes
    private var dbType = ""
    // Database name; only parsed from MySQL URLs
    private var dbName = ""
    // JDBC driver class name; "" for unrecognized URL prefixes
    private var driver = ""
    private var url = ""
    private var user = ""
    private var password = ""

    /**
     * Metadata of all tables. Every read calls [build], i.e. opens a new connection
     * and re-reads the metadata.
     */
    val tableMetaList: List<TableMeta>
        get() = build()

    init {
        when {
            url.startsWith("jdbc:mysql") -> {
                dbType = "mysql"
                // jdbc:mysql://host:port/dbName?params -> "dbName"; the query string is optional
                dbName = url.substringBefore('?').substringAfter("//").substringAfter('/', "")
                driver = "com.mysql.jdbc.Driver"
            }
            url.startsWith("jdbc:oracle") -> {
                dbType = "oracle"
                driver = "oracle.jdbc.OracleDriver"
            }
            url.startsWith("jdbc:sqlserver") -> {
                dbType = "sqlserver"
                driver = "com.microsoft.sqlserver.jdbc.SQLServerDriver"
            }
            url.startsWith("jdbc:postgresql") -> {
                dbType = "postgresql"
                driver = "org.postgresql.Driver"
            }
            url.startsWith("jdbc:db2") -> {
                dbType = "db2"
                driver = "com.ibm.db2.jcc.DB2Driver"
            }
            url.startsWith("jdbc:hsqldb") -> {
                dbType = "hsqldb"
                driver = "org.hsqldb.jdbcDriver"
            }
            url.startsWith("jdbc:derby") -> {
                dbType = "derby"
                driver = "org.apache.derby.jdbc.ClientDriver"
            }
        }
        // MySQL only returns table/column remarks when useInformationSchema=true is set
        this.url = if ("mysql" == dbType && !url.contains("useInformationSchema")) {
            url + (if (url.contains("?")) "&" else "?") + "useInformationSchema=true"
        } else {
            url
        }
        this.user = user
        this.password = password
    }

    /**
     * Decides whether a table is left out of the output. The default keeps every table;
     * override (or edit) this to filter tables.
     * @param tableName table name as reported by the JDBC driver
     * @return `true` to skip the table
     */
    protected open fun isSkip(tableName: String): Boolean = false

    /**
     * Renders the structure of this database to a single HTML file
     * (template `generator/single_db-template.html`).
     * @param path output file; write failures are logged, not thrown
     */
    fun render(path: File) {
        renderToFile("generator/single_db-template.html", "tables", build(), path)
    }

    /**
     * Renders the structures of several databases into one HTML file
     * (template `generator/batch_db-template.html`).
     * @param batchDbTableMetas table metadata keyed by database
     * @param path output file; write failures are logged, not thrown
     */
    fun renderBatch(batchDbTableMetas: LinkedHashMap<String, List<TableMeta>>, path: File) {
        renderToFile("generator/batch_db-template.html", "batchDbTables", batchDbTableMetas, path)
    }

    /** Renders [templateName] with the single variable [key] = [value] and writes it to [path] as UTF-8. */
    private fun renderToFile(templateName: String, key: String, value: Any, path: File) {
        val engine = Engine.use().setSourceFactory(ClassPathSourceFactory())
        val html = engine.getTemplate(templateName).renderToString(Kv.by(key, value))
        try {
            FileOutputStream(path).use { out -> out.write(html.toByteArray(Charsets.UTF_8)) }
        } catch (e: Exception) {
            logger.error("生成失败！", e)
        }
    }

    /**
     * Reads the metadata of every table (views are not included).
     * Any failure is logged and the tables collected so far are returned.
     * The connection and result set are always closed.
     * @return table metadata, possibly empty
     */
    private fun build(): List<TableMeta> {
        logger.info("Start to build TableMeta ...")
        val dbTableMetas = ArrayList<TableMeta>()
        var conn: Connection? = null
        var rs: ResultSet? = null
        try {
            val connection = checkNotNull(connect()) { "No database connection" }
            conn = connection
            val dbMeta = connection.metaData
            // Only TABLE objects are listed; views are not supported.
            // NOTE(review): schemaPattern "" matches only tables without a schema, which may return
            // nothing on schema-based databases (e.g. PostgreSQL) — confirm per database type.
            val tables = dbMeta.getTables(connection.catalog, "", "%", arrayOf("TABLE"))
            rs = tables
            while (tables.next()) {
                val tableName = tables.getString("TABLE_NAME")
                if (isSkip(tableName)) {
                    logger.info("Skip table：$tableName")
                    continue
                }
                // REMARKS is null for tables without a comment; TableMeta.remark is non-null
                val tableRemark = tables.getString("REMARKS") ?: ""
                val primaryKey = getPrimaryKey(dbMeta, connection, tableName)
                val tableColumns = ArrayList<ColumnMeta>()
                getTableInfo(connection, dbMeta, tableName, primaryKey, tableColumns)

                dbTableMetas.add(
                    TableMeta(
                        dbType = dbType,
                        dbName = dbName,
                        tableName = tableName,
                        primaryKey = primaryKey,
                        remark = tableRemark,
                        columns = tableColumns
                    )
                )
            }
        } catch (e: Exception) {
            logger.error("Build TableMeta Exception!", e)
        } finally {
            close(rs, conn)
        }

        logger.info("Build TableMeta Finished!")
        return dbTableMetas
    }

    /**
     * Opens a connection with the configured URL and credentials.
     * @return the connection, or `null` when connecting failed (the failure is logged)
     */
    private fun connect(): Connection? = try {
        // JDBC 4 drivers register themselves; only load explicitly when the driver is known
        if (driver.isNotEmpty()) {
            Class.forName(driver)
        }
        DriverManager.getConnection(url, user, password).also { conn ->
            if (!conn.isClosed) {
                logger.info("Connect to database success!")
            }
        }
    } catch (e: Exception) {
        logger.error("Sorry, Connect to database failed!", e)
        null
    }

    /**
     * Reads the primary-key columns of a table.
     * @param dbMeta metadata of [conn]
     * @param conn open connection; its catalog is used for the lookup
     * @param tableName table to inspect
     * @return comma-separated key column names in key order; "" when there is none or reading failed
     */
    private fun getPrimaryKey(dbMeta: DatabaseMetaData, conn: Connection, tableName: String): String {
        val keys = ArrayList<Pair<Int, String>>()
        var rs: ResultSet? = null
        try {
            val pkRs = dbMeta.getPrimaryKeys(conn.catalog, null, tableName)
            rs = pkRs
            while (pkRs.next()) {
                keys.add(pkRs.getInt("KEY_SEQ") to pkRs.getString("COLUMN_NAME"))
            }
        } catch (e: Exception) {
            logger.error("Failed to read primary key of table $tableName", e)
        } finally {
            closeResultSet(rs)
        }
        // getPrimaryKeys orders rows by COLUMN_NAME; KEY_SEQ restores the declared composite-key order
        return keys.sortedBy { it.first }.joinToString(",") { it.second }
    }

    /**
     * Reads the column metadata of a table and appends one [ColumnMeta] per column to [columns].
     * Failures are logged; columns read before the failure are kept.
     * @param conn open connection; its catalog is used for the lookup
     * @param dbMeta metadata of [conn]
     * @param tableName table to inspect
     * @param primaryKeys comma-separated primary-key column names (matched case-insensitively)
     * @param columns output list
     */
    private fun getTableInfo(conn: Connection, dbMeta: DatabaseMetaData, tableName: String, primaryKeys: String?, columns: MutableList<ColumnMeta>) {
        // Parse the key list once instead of once per column
        val keys = primaryKeys.orEmpty().split(",").filter { it.isNotEmpty() }
        var rs: ResultSet? = null
        try {
            val colRs = dbMeta.getColumns(conn.catalog, null, tableName, null)
            rs = colRs
            while (colRs.next()) {
                val columnName: String = colRs.getString("COLUMN_NAME") ?: ""
                val columnType: String = colRs.getString("TYPE_NAME") ?: ""
                val columnSize = colRs.getInt("COLUMN_SIZE")

                // e.g. "varchar(255)" or, with fractional digits, "decimal(10,2)"
                var dataType = ""
                if (columnSize > 0) {
                    val decimalDigits = colRs.getInt("DECIMAL_DIGITS") // number of fractional digits
                    dataType = if (decimalDigits > 0) {
                        "$columnType($columnSize,$decimalDigits)"
                    } else {
                        "$columnType($columnSize)"
                    }
                }

                val isPrimaryKey = keys.any { columnName.equals(it, ignoreCase = true) }
                // JDBC uses "" for unknown nullability; guard against drivers returning null
                val nullable = colRs.getString("IS_NULLABLE") ?: ""
                val defaultValue = colRs.getString("COLUMN_DEF")
                val remark = colRs.getString("REMARKS")

                columns.add(
                    ColumnMeta(
                        tableName = tableName,
                        columnName = columnName,
                        columnType = columnType,
                        columnSize = columnSize.toString(),
                        dataType = dataType,
                        defaultValue = defaultValue ?: "NULL",
                        remark = remark ?: "",
                        nullable = nullable,
                        primaryKey = isPrimaryKey
                    )
                )
            }
        } catch (e: Exception) {
            logger.error("获取表结构信息异常", e)
        } finally {
            closeResultSet(rs)
        }
    }

    /** Closes the result set and then the connection; either may be null. */
    private fun close(rs: ResultSet?, conn: Connection?) {
        closeResultSet(rs)
        closeConn(conn)
    }

    /** Closes [rs] if it is open; close failures are logged and ignored. */
    private fun closeResultSet(rs: ResultSet?) {
        try {
            if (rs != null && !rs.isClosed) {
                rs.close()
            }
        } catch (e: Exception) {
            logger.warn("Failed to close ResultSet", e)
        }
    }

    /** Closes [conn] if it is open; close failures are logged and ignored. */
    private fun closeConn(conn: Connection?) {
        try {
            if (conn != null && !conn.isClosed) {
                conn.close()
            }
        } catch (e: Exception) {
            logger.warn("Failed to close Connection", e)
        }
    }
}

/**
 * Command-line example that exports the table structure of the local MySQL
 * database `shopxo` to an HTML file.
 *
 * Usage: `Demo1 [outputDir]`. The output directory defaults to `D:\`.
 */
object Demo1 {

    private const val IP = "127.0.0.1"
    private const val PORT = 3306
    private const val USER = "root"
    private const val PASSWORD = "123456"
    private const val DB_NAME = "shopxo"
    // Format placeholders: host, port, database name
    private const val URL = "jdbc:mysql://%s:%s/%s?zeroDateTimeBehavior=convertToNull&useInformationSchema=true"

    @JvmStatic
    fun main(args: Array<String>) {
        // The first argument, when given and non-blank, overrides the default output directory
        val outputBasePath = args.firstOrNull()?.takeIf { it.isNotBlank() } ?: "D:\\"
        val baseFile = File(outputBasePath)
        if (!baseFile.exists()) {
            baseFile.mkdirs()
        }

        try {
            val file = newFile(baseFile.path, "数据库表结构信息导出")
            if (file != null) {
                DbBuilder(String.format(URL, IP, PORT, DB_NAME), USER, PASSWORD).render(file)
                println("生成成功!")
            } else {
                System.err.println("生成失败!")
            }
        } catch (e: Exception) {
            System.err.println("生成失败!")
            e.printStackTrace()
        }
    }

    /**
     * Returns `<basePath>/<fileName>.html`, creating the file if it does not exist yet.
     * @return the file, or `null` when it could not be created
     */
    private fun newFile(basePath: String, fileName: String): File? {
        return try {
            File(basePath + File.separator + fileName + ".html").also { file ->
                if (!file.exists()) {
                    file.createNewFile()
                }
            }
        } catch (e: Exception) {
            e.printStackTrace()
            null
        }
    }
}